{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Rank Classification using BERT on Amazon Review dataset\n",
    "\n",
    "## Introduction\n",
    "\n",
    "In this tutorial, you learn how to train a rank classification model using [Transfer Learning](https://en.wikipedia.org/wiki/Transfer_learning). We will use a pretrained DistilBert model to train on the Amazon review dataset.\n",
    "\n",
    "## About the dataset and model\n",
    "\n",
    "[Amazon Customer Review dataset](https://s3.amazonaws.com/amazon-reviews-pds/readme.html) consists of all different valid reviews from amazon.com. We will use the \"Digital_software\" category that consists of 102k valid reviews. As for the pre-trained model, use the DistilBERT[[1]](https://arxiv.org/abs/1910.01108) model. It's a light-weight BERT model already trained on [Wikipedia text corpora](https://en.wikipedia.org/wiki/List_of_text_corpora), a much larger dataset consisting of over millions text. The DistilBERT served as a base layer and we will add some more classification layers to output as rankings (1 - 5).\n",
    "\n",
    "<img src=\"https://djl-ai.s3.amazonaws.com/resources/images/amazon_review.png\" width=\"500\">\n",
    "<center>Amazon Review example</center>\n",
    "\n",
    "We will use review body as our data input and ranking as label.\n",
    "\n",
    "\n",
    "## Pre-requisites\n",
    "This tutorial assumes you have the following knowledge. Follow the READMEs and tutorials if you are not familiar with:\n",
    "1. How to setup and run [Java Kernel in Jupyter Notebook](https://github.com/deepjavalibrary/djl/blob/master/jupyter/README.md)\n",
    "2. Basic components of Deep Java Library, and how to [train your first model](https://github.com/deepjavalibrary/djl/blob/master/jupyter/tutorial/02_train_your_first_model.ipynb).\n",
    "\n",
    "\n",
    "## Getting started\n",
    "Load the Deep Java Libarary and its dependencies from Maven. In here, you can choose between MXNet or PyTorch. MXNet is enabled by default. You can uncomment PyTorch dependencies and comment MXNet ones to switch to PyTorch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n",
    "\n",
    "%maven ai.djl:api:0.12.0\n",
    "%maven ai.djl:basicdataset:0.12.0\n",
    "%maven org.slf4j:slf4j-api:1.7.26\n",
    "%maven org.slf4j:slf4j-simple:1.7.26\n",
    "        \n",
    "// See https://github.com/deepjavalibrary/djl/blob/master/engines/mxnet/mxnet-engine/README.md\n",
    "// MXNet \n",
    "%maven ai.djl.mxnet:mxnet-model-zoo:0.12.0\n",
    "%maven ai.djl.mxnet:mxnet-native-auto:1.8.0\n",
    "\n",
    "// PyTorch\n",
    "// %maven ai.djl.pytorch:pytorch-model-zoo:0.12.0\n",
    "// %maven ai.djl.pytorch:pytorch-native-auto:1.8.1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's import the necessary modules:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ai.djl.*;\n",
    "import ai.djl.basicdataset.tabular.*;\n",
    "import ai.djl.basicdataset.utils.*;\n",
    "import ai.djl.engine.*;\n",
    "import ai.djl.inference.*;\n",
    "import ai.djl.metric.*;\n",
    "import ai.djl.modality.*;\n",
    "import ai.djl.modality.nlp.*;\n",
    "import ai.djl.modality.nlp.bert.*;\n",
    "import ai.djl.ndarray.*;\n",
    "import ai.djl.ndarray.types.*;\n",
    "import ai.djl.nn.*;\n",
    "import ai.djl.nn.core.*;\n",
    "import ai.djl.nn.norm.*;\n",
    "import ai.djl.repository.zoo.*;\n",
    "import ai.djl.training.*;\n",
    "import ai.djl.training.dataset.*;\n",
    "import ai.djl.training.evaluator.*;\n",
    "import ai.djl.training.listener.*;\n",
    "import ai.djl.training.loss.*;\n",
    "import ai.djl.training.util.*;\n",
    "import ai.djl.translate.*;\n",
    "import java.io.*;\n",
    "import java.nio.file.*;\n",
    "import java.util.*;\n",
    "import org.apache.commons.csv.*;\n",
    "\n",
    "System.out.println(\"You are using: \" + Engine.getInstance().getEngineName() + \" Engine\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare Dataset\n",
    "\n",
    "First step is to prepare the dataset for training. Since the original data was in TSV format, we can use CSVDataset to be the dataset container. We will also need to specify how do we want to preprocess the raw data. For BERT model, the input data are required to be tokenized and mapped into indices based on the inputs. In DJL, we defined an interface called Fearurizer, it is designed to allow user customize operation on each selected row/column of a dataset. In our case, we would like to clean and tokenize our sentencies. So let's try to implement it to deal with customer review sentencies."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "final class BertFeaturizer implements CsvDataset.Featurizer {\n",
    "\n",
    "    private final BertFullTokenizer tokenizer;\n",
    "    private final int maxLength; // the cut-off length\n",
    "\n",
    "    public BertFeaturizer(BertFullTokenizer tokenizer, int maxLength) {\n",
    "        this.tokenizer = tokenizer;\n",
    "        this.maxLength = maxLength;\n",
    "    }\n",
    "\n",
    "    /** {@inheritDoc} */\n",
    "    @Override\n",
    "    public void featurize(DynamicBuffer buf, String input) {\n",
    "        DefaultVocabulary vocab = tokenizer.getVocabulary();\n",
    "        // convert sentence to tokens (toLowerCase for uncased model)\n",
    "        List<String> tokens = tokenizer.tokenize(input.toLowerCase());\n",
    "        // trim the tokens to maxLength\n",
    "        tokens = tokens.size() > maxLength ? tokens.subList(0, maxLength) : tokens;\n",
    "        // BERT embedding convention \"[CLS] Your Sentence [SEP]\"\n",
    "        buf.put(vocab.getIndex(\"[CLS]\"));\n",
    "        tokens.forEach(token -> buf.put(vocab.getIndex(token)));\n",
    "        buf.put(vocab.getIndex(\"[SEP]\"));\n",
    "    }\n",
    "}"
   ]
  },
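  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the output of `featurize` more concrete, the following sketch repeats the same steps (truncate to the cut-off length, then wrap the tokens with `[CLS]` and `[SEP]`) in plain Java without DJL. The class `FeaturizeSketch`, its methods, and the tiny vocabulary in `demo` are hypothetical and only meant as an illustration; the real indices come from the model's `vocab.txt`. It also assumes that tokens missing from the vocabulary map to `[UNK]`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import java.util.*;\n",
    "\n",
    "// Hypothetical stand-in for BertFeaturizer that works on a plain Map vocabulary.\n",
    "final class FeaturizeSketch {\n",
    "\n",
    "    static List<Long> toIndices(List<String> tokens, Map<String, Long> vocab, int maxLength) {\n",
    "        // keep at most maxLength tokens, like BertFeaturizer does\n",
    "        List<String> kept = tokens.size() > maxLength ? tokens.subList(0, maxLength) : tokens;\n",
    "        List<Long> indices = new ArrayList<>();\n",
    "        indices.add(vocab.get(\"[CLS]\"));\n",
    "        for (String token : kept) {\n",
    "            // assumption: tokens missing from the vocabulary map to [UNK]\n",
    "            indices.add(vocab.getOrDefault(token, vocab.get(\"[UNK]\")));\n",
    "        }\n",
    "        indices.add(vocab.get(\"[SEP]\"));\n",
    "        return indices;\n",
    "    }\n",
    "\n",
    "    // Made-up vocabulary; with maxLength = 3 the token \"today\" is cut off.\n",
    "    static List<Long> demo() {\n",
    "        Map<String, Long> vocab = new HashMap<>();\n",
    "        String[] words = {\"[PAD]\", \"[UNK]\", \"[CLS]\", \"[SEP]\", \"it\", \"works\", \"great\"};\n",
    "        for (int i = 0; i < words.length; i++) {\n",
    "            vocab.put(words[i], (long) i);\n",
    "        }\n",
    "        // returns [2, 4, 5, 6, 3]\n",
    "        return toIndices(Arrays.asList(\"it\", \"works\", \"great\", \"today\"), vocab, 3);\n",
    "    }\n",
    "}"
   ]
  },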
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once we got this part done, we can apply the `BertFeaturizer` into our Dataset. We take `review_body` column and apply the Featurizer. We also pick `star_rating` as our label set. Since we go for batch input, we need to tell the dataset to pad our data if it is less than the `maxLength` we defined. `PaddingStackBatchifier` will do the work for you."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "CsvDataset getDataset(int batchSize, BertFullTokenizer tokenizer, int maxLength, int limit) {\n",
    "    String amazonReview =\n",
    "            \"https://s3.amazonaws.com/amazon-reviews-pds/tsv/amazon_reviews_us_Digital_Software_v1_00.tsv.gz\";\n",
    "    float paddingToken = tokenizer.getVocabulary().getIndex(\"[PAD]\");\n",
    "    return CsvDataset.builder()\n",
    "            .optCsvUrl(amazonReview) // load from Url\n",
    "            .setCsvFormat(CSVFormat.TDF.withQuote(null).withHeader()) // Setting TSV loading format\n",
    "            .setSampling(batchSize, true) // make sample size and random access\n",
    "            .optLimit(limit)\n",
    "            .addFeature(\n",
    "                    new CsvDataset.Feature(\n",
    "                            \"review_body\", new BertFeaturizer(tokenizer, maxLength)))\n",
    "            .addLabel(\n",
    "                    new CsvDataset.Feature(\n",
    "                            \"star_rating\", (buf, data) -> buf.put(Float.parseFloat(data) - 1.0f)))\n",
    "            .optDataBatchifier(\n",
    "                    PaddingStackBatchifier.builder()\n",
    "                            .optIncludeValidLengths(false)\n",
    "                            .addPad(0, 0, (m) -> m.ones(new Shape(1)).mul(paddingToken))\n",
    "                            .build()) // define how to pad dataset to a fix length\n",
    "            .build();\n",
    "}"
   ]
  },
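  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The idea behind the padding step can be sketched in plain Java as follows. This is not how `PaddingStackBatchifier` is implemented; the class `PaddingSketch` and its `padBatch` method are hypothetical, and the sketch assumes that every sequence is padded at the end with the `[PAD]` index up to the length of the longest sequence in the batch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import java.util.*;\n",
    "\n",
    "// Hypothetical illustration of padding a batch of token index sequences.\n",
    "final class PaddingSketch {\n",
    "\n",
    "    // Pads every sequence at the end with padIndex up to the longest length in the batch.\n",
    "    static long[][] padBatch(List<long[]> sequences, long padIndex) {\n",
    "        int maxLen = 0;\n",
    "        for (long[] seq : sequences) {\n",
    "            maxLen = Math.max(maxLen, seq.length);\n",
    "        }\n",
    "        long[][] batch = new long[sequences.size()][maxLen];\n",
    "        for (int i = 0; i < sequences.size(); i++) {\n",
    "            // fill the row with the padding index, then copy the real tokens over it\n",
    "            Arrays.fill(batch[i], padIndex);\n",
    "            long[] seq = sequences.get(i);\n",
    "            System.arraycopy(seq, 0, batch[i], 0, seq.length);\n",
    "        }\n",
    "        return batch;\n",
    "    }\n",
    "}"
   ]
  },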
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Construct your model\n",
    "\n",
    "We will load our pretrained model and prepare the classification. First construct the `criteria` to specify where to load the embedding (DistiledBERT), then call `loadModel` to download that embedding with pre-trained weights. Since this model is built without classification layer, we need to add a classification layer to the end of the model and train it. After you are done modifying the block, set it back to model using `setBlock`.\n",
    "\n",
    "### Load the word embedding\n",
    "\n",
    "We will download our word embedding and load it to memory (this may take a while)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// MXNet base model\n",
    "String modelUrls = \"https://resources.djl.ai/test-models/distilbert.zip\";\n",
    "if (\"PyTorch\".equals(Engine.getInstance().getEngineName())) {\n",
    "    modelUrls = \"https://resources.djl.ai/test-models/traced_distilbert_wikipedia_uncased.zip\";\n",
    "}\n",
    "\n",
    "Criteria<NDList, NDList> criteria = Criteria.builder()\n",
    "        .optApplication(Application.NLP.WORD_EMBEDDING)\n",
    "        .setTypes(NDList.class, NDList.class)\n",
    "        .optModelUrls(modelUrls)\n",
    "        .optProgress(new ProgressBar())\n",
    "        .build();\n",
    "ZooModel<NDList, NDList> embedding = criteria.loadModel();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create classification layers\n",
    "\n",
    "Then let's build a simple MLP layer to classify the ranks. We set the output of last FullyConnected (Linear) layer to 5 to get the predictions for star 1 to 5. Then all we need to do is to load the block into the model. Before applying the classification layer, we also need to add text embedding to the front. In our case, we just create a Lambda function that do the followings:\n",
    "\n",
    "1. batch_data (batch size, token indices) -> batch_data + max_length (size of the token indices)\n",
    "2. generate embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Predictor<NDList, NDList> embedder = embedding.newPredictor();\n",
    "Block classifier = new SequentialBlock()\n",
    "        // text embedding layer\n",
    "        .add(\n",
    "            ndList -> {\n",
    "                NDArray data = ndList.singletonOrThrow();\n",
    "                NDList inputs = new NDList();\n",
    "                long batchSize = data.getShape().get(0);\n",
    "                float maxLength = data.getShape().get(1);\n",
    "\n",
    "                if (\"PyTorch\".equals(Engine.getInstance().getEngineName())) {\n",
    "                    inputs.add(data.toType(DataType.INT64, false));\n",
    "                    inputs.add(data.getManager().full(data.getShape(), 1, DataType.INT64));\n",
    "                    inputs.add(data.getManager().arange(maxLength)\n",
    "                               .toType(DataType.INT64, false)\n",
    "                               .broadcast(data.getShape()));\n",
    "                } else {\n",
    "                    inputs.add(data);\n",
    "                    inputs.add(data.getManager().full(new Shape(batchSize), maxLength));\n",
    "                }\n",
    "                // run embedding\n",
    "                try {\n",
    "                    return embedder.predict(inputs);\n",
    "                } catch (TranslateException e) {\n",
    "                    throw new IllegalArgumentException(\"embedding error\", e);\n",
    "                }\n",
    "            })\n",
    "        // classification layer\n",
    "        .add(Linear.builder().setUnits(768).build()) // pre classifier\n",
    "        .add(Activation::relu)\n",
    "        .add(Dropout.builder().optRate(0.2f).build())\n",
    "        .add(Linear.builder().setUnits(5).build()) // 5 star rating\n",
    "        .addSingleton(nd -> nd.get(\":,0\")); // Take [CLS] as the head\n",
    "Model model = Model.newInstance(\"AmazonReviewRatingClassification\");\n",
    "model.setBlock(classifier);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Start Training\n",
    "\n",
    "Finally, we can start building our training pipeline to train the model.\n",
    "\n",
    "### Creating Training and Testing dataset\n",
    "\n",
    "Firstly, we need to create a voabulary that is used to map token to index such as \"hello\" to 1121 (1121 is the index of \"hello\" in dictionary). Then we simply feed the vocabulary to the tokenizer that used to tokenize the sentence. Finally, we just need to split the dataset based on the ratio.\n",
    "\n",
    "Note: we set the cut-off length to 64 which means only the first 64 tokens from the review will be used. You can increase this value to achieve better accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Prepare the vocabulary\n",
    "DefaultVocabulary vocabulary = DefaultVocabulary.builder()\n",
    "        .addFromTextFile(embedding.getArtifact(\"vocab.txt\"))\n",
    "        .optUnknownToken(\"[UNK]\")\n",
    "        .build();\n",
    "// Prepare dataset\n",
    "int maxTokenLength = 64; // cutoff tokens length\n",
    "int batchSize = 8;\n",
    "int limit = Integer.MAX_VALUE;\n",
    "// int limit = 512; // uncomment for quick testing\n",
    "\n",
    "BertFullTokenizer tokenizer = new BertFullTokenizer(vocabulary, true);\n",
    "CsvDataset amazonReviewDataset = getDataset(batchSize, tokenizer, maxTokenLength, limit);\n",
    "// split data with 7:3 train:valid ratio\n",
    "RandomAccessDataset[] datasets = amazonReviewDataset.randomSplit(7, 3);\n",
    "RandomAccessDataset trainingSet = datasets[0];\n",
    "RandomAccessDataset validationSet = datasets[1];"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup Trainer and training config\n",
    "\n",
    "Then, we need to setup our trainer. We set up the accuracy and loss function. The model training logs will be saved to `build/modlel`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SaveModelTrainingListener listener = new SaveModelTrainingListener(\"build/model\");\n",
    "        listener.setSaveModelCallback(\n",
    "            trainer -> {\n",
    "                TrainingResult result = trainer.getTrainingResult();\n",
    "                Model model = trainer.getModel();\n",
    "                // track for accuracy and loss\n",
    "                float accuracy = result.getValidateEvaluation(\"Accuracy\");\n",
    "                model.setProperty(\"Accuracy\", String.format(\"%.5f\", accuracy));\n",
    "                model.setProperty(\"Loss\", String.format(\"%.5f\", result.getValidateLoss()));\n",
    "            });\n",
    "DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) // loss type\n",
    "        .addEvaluator(new Accuracy())\n",
    "        .optDevices(Device.getDevices(1)) // train using single GPU\n",
    "        .addTrainingListeners(TrainingListener.Defaults.logging(\"build/model\"))\n",
    "        .addTrainingListeners(listener);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Start training\n",
    "\n",
    "We will start our training process. Training on GPU will takes approximately 10 mins. For CPU, it will take more than 2 hours to finish."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "int epoch = 2;\n",
    "\n",
    "Trainer trainer = model.newTrainer(config);\n",
    "trainer.setMetrics(new Metrics());\n",
    "Shape encoderInputShape = new Shape(batchSize, maxTokenLength);\n",
    "// initialize trainer with proper input shape\n",
    "trainer.initialize(encoderInputShape);\n",
    "EasyTrain.fit(trainer, epoch, trainingSet, validationSet);\n",
    "System.out.println(trainer.getTrainingResult());"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save(Paths.get(\"build/model\"), \"amazon-review.param\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Verify the model\n",
    "\n",
    "We can create a predictor from the model to run inference on our customized dataset. Firstly, we can create a `Translator` for the model to do preprocessing and post processing. Similar to what we have done before, we need to tokenize the input sentence and get the output ranking."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyTranslator implements Translator<String, Classifications> {\n",
    "\n",
    "    private BertFullTokenizer tokenizer;\n",
    "    private DefaultVocabulary vocab;\n",
    "    private List<String> ranks;\n",
    "\n",
    "    public MyTranslator(BertFullTokenizer tokenizer) {\n",
    "        this.tokenizer = tokenizer;\n",
    "        vocab = tokenizer.getVocabulary();\n",
    "        ranks = Arrays.asList(\"1\", \"2\", \"3\", \"4\", \"5\");\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    public Batchifier getBatchifier() { return new StackBatchifier(); }\n",
    "\n",
    "    @Override\n",
    "    public NDList processInput(TranslatorContext ctx, String input) {\n",
    "        List<String> tokens = tokenizer.tokenize(input);\n",
    "        float[] indices = new float[tokens.size() + 2];\n",
    "        indices[0] = vocab.getIndex(\"[CLS]\");\n",
    "        for (int i = 0; i < tokens.size(); i++) {\n",
    "            indices[i+1] = vocab.getIndex(tokens.get(i));\n",
    "        }\n",
    "        indices[indices.length - 1] = vocab.getIndex(\"[SEP]\");\n",
    "        return new NDList(ctx.getNDManager().create(indices));\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    public Classifications processOutput(TranslatorContext ctx, NDList list) {\n",
    "        return new Classifications(ranks, list.singletonOrThrow().softmax(0));\n",
    "    }\n",
    "}"
   ]
  },
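  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`processOutput` applies a softmax to the five raw scores, and `Classifications` pairs each probability with a rank name. Because the labels were stored as `star_rating - 1` during training, class index 0 corresponds to 1 star and class index 4 to 5 stars. The plain Java sketch below illustrates that mapping; the class `RankSketch` and its methods are hypothetical and do not use the DJL API."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Hypothetical illustration of turning raw scores into a star rating.\n",
    "final class RankSketch {\n",
    "\n",
    "    // Numerically stable softmax: subtract the maximum before exponentiating.\n",
    "    static double[] softmax(double[] logits) {\n",
    "        double max = Double.NEGATIVE_INFINITY;\n",
    "        for (double v : logits) {\n",
    "            max = Math.max(max, v);\n",
    "        }\n",
    "        double sum = 0;\n",
    "        double[] probs = new double[logits.length];\n",
    "        for (int i = 0; i < logits.length; i++) {\n",
    "            probs[i] = Math.exp(logits[i] - max);\n",
    "            sum += probs[i];\n",
    "        }\n",
    "        for (int i = 0; i < probs.length; i++) {\n",
    "            probs[i] /= sum;\n",
    "        }\n",
    "        return probs;\n",
    "    }\n",
    "\n",
    "    // Picks the most probable class and maps index 0..4 back to 1..5 stars.\n",
    "    static int predictedStars(double[] logits) {\n",
    "        double[] probs = softmax(logits);\n",
    "        int best = 0;\n",
    "        for (int i = 1; i < probs.length; i++) {\n",
    "            if (probs[i] > probs[best]) {\n",
    "                best = i;\n",
    "            }\n",
    "        }\n",
    "        return best + 1;\n",
    "    }\n",
    "}"
   ]
  },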
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we can create a `Predictor` to run the inference. Let's try with a random customer review:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "String review = \"It works great, but it takes too long to update itself and slows the system\";\n",
    "Predictor<String, Classifications> predictor = model.newPredictor(new MyTranslator(tokenizer));\n",
    "System.out.println(predictor.predict(review));"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
